{-# LANGUAGE CPP #-}
-- for unboxed shifts
-----------------------------------------------------------------------------
-- |
-- Module      : LLVM.Parser.Bits.BitGet
-- Copyright   : Mikhail Belyaev
-- License     : GPLv3 (see LICENSE)
--
-- Forked version of Binary.Strict.BitGet by Adam Langley
-- with methods bitsParsed and bytesParsed added
--
-- This is a reader monad for parsing bit-aligned data. The usual Get monad
-- handles byte aligned data well.
--
-- In this monad, the current offset into the input is a number of bits, and
-- fetching n bits from the current position will shift everything correctly.
-- Bit vectors are represented as ByteStrings where either the first @n@ bits
-- are valid (left aligned) or the last @n@ bits are (right aligned).
--
-- If one is looking to parse integers etc., right alignment is the easiest to
-- work with; however, left alignment makes more sense in some situations.
-----------------------------------------------------------------------------

-- TODO: A strictness question may arise - this monad should really be lazy
-- (that is, it should work with ByteString.Lazy)


-----------------------------------------------------------------------------
-- Modification 21.10.2010 -- changed to using Control.Monad.State as base
-- Modification nn.nn.2010 -- changed to using Error-based monad to control errors
-----------------------------------------------------------------------------
module LLVM.Parser.Bits.BitGet (
  -- * Get @BitGet@ type
    BitGet
  , ErrBitGet
  , unliftError
  , failsafe
  , runBitGet

  -- * Utility
  , skip
  , remaining
  , bytesParsed
  , bitsParsed
  , isEmpty
  , lookAhead

  -- * Generic parsing
  , getBit
  , getLeftByteString
  , getRightByteString

  -- ** Interpreting some number of bits as an integer
  , getAsWord8
  , getAsWord16
  , getAsWord32
  , getAsWord64

  -- ** Parsing particular types
  , getWord8
  , getWord16le
  , getWord16be
  , getWord16host
  , getWord32le
  , getWord32be
  , getWord32host
  , getWord64le
  , getWord64be
  , getWord64host
  , getWordhost
) where

#include "Common.h"

import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as BI
import Data.Binary.Strict.BitUtil
import Foreign
import Data.Bits

import Control.Monad.State
import Control.Monad.Error

import Data.Word

import LLVM.Util

-- | Parser state threaded through the @BitGet@ monad.
data S = S
  { input     :: !B.ByteString -- ^ remaining input; the first byte may be partially consumed
  , bitOffset :: !Word8        -- ^ bit offset in the current (first) byte
  , totalRead :: !Word64       -- ^ total number of whole bytes read so far
  }

-- | The BitGet monad: a State monad over the parser state @S@. It has no
--   error channel, so it should not be used for operations that can fail.
type BitGet = State S
-- | The ErrBitGet monad: the error-stacked version of @BitGet@, built by
--   applying @ErrorOf@ to it. @ErrorOf@ is not defined in this module
--   (presumably it comes from LLVM.Util - TODO confirm).
type ErrBitGet = ErrorOf BitGet

-- | Evaluate a @BitGet@ action against the given ByteString and return its
--   result. Parsing starts at bit offset 0 with zero bytes counted as read.
runBitGet :: B.ByteString -> BitGet a -> a
runBitGet bytes action = evalState action initialState
  where
    initialState = S { input = bytes, bitOffset = 0, totalRead = 0 }

-- | A variant of the standard splitAt in which both halves share one byte:
--   the last byte of the prefix is also the first byte of the suffix.
--   For example, splitting [1,2,3,4] at 2 yields ([1,2], [2,3,4]).
splitAtWithDupByte :: Int -> B.ByteString -> (B.ByteString, B.ByteString)
splitAtWithDupByte n bs = (prefix, suffix)
  where
    prefix = B.take n bs
    -- start one byte earlier so that the boundary byte is duplicated
    suffix = B.drop (n - 1) bs

-- | Flag argument to readN, meant to select whether the resulting
--   ByteString is left aligned or right aligned.
data Direction
  = BLeft
  | BRight
  deriving (Show)

-- | Post-process the raw bytes fetched by readN: reverse the byte order,
--   pass the result through rightShift with @v@, then through
--   rightTruncateBits with @n@.
--   NOTE(review): rightShift and rightTruncateBits are imported helpers;
--   presumably they shift right by @v@ bits and keep the rightmost @n@
--   bits - confirm against Data.Binary.Strict.BitUtil.
apply :: Int -> Int -> B.ByteString -> B.ByteString
apply v n bs = rightTruncateBits n shifted
  where
    shifted = rightShift v (B.reverse bs)

-- | Fetch @n@ bits from the input and return the result of applying @f@ to
--   the ByteString that holds them.
--
--   The state is advanced by @n@ bits: fully consumed bytes are dropped from
--   the input and added to the byte counter, while a partially consumed byte
--   stays at the head of the input together with the new bit offset.
--
--   Fails with an error when @n@ is negative or when fewer than @n@ bits
--   remain. A negative @n@ used to slip past the remaining-bits check and
--   wrap the unsigned byte counter, so it is now rejected explicitly.
--
--   NOTE(review): the Direction argument is kept for interface compatibility
--   but is not consulted; the bits are always extracted with @apply@.
--   Confirm whether left alignment should still be supported.
readN :: Direction -> Int -> (B.ByteString -> a) -> ErrBitGet a
readN _direction n f = do
  S bytes boff proced <- get
  let boffInt = fromIntegral boff
      bitsRemaining = B.length bytes * 8 - boffInt
  if n < 0
     then throwError "Negative bit count requested"
     else if bitsRemaining < n
             then throwError "Too few bits remain"
             else do
               let -- number of input bytes touched by the requested bits
                   bytesRequired = ((n - 1 + boffInt) `div` 8) + 1
                   -- bit offset inside the first unconsumed byte afterwards
                   qboff = (boffInt + n) `mod` 8
                   -- only whole bytes are added to the running byte count
                   qproced = proced + (fromIntegral (boffInt + n) `div` 8)
                   -- when the read ends inside a byte, that byte must stay
                   -- at the head of the remaining input as well
                   (r, rest) = if qboff == 0
                                  then B.splitAt bytesRequired bytes
                                  else splitAtWithDupByte bytesRequired bytes
               put $ S rest (fromIntegral qboff) qproced
               return $ f $ apply boffInt n r

-- | Skip @n@ bits of the input, discarding them.
--   Fails if fewer than @n@ bits remain.
skip :: Int -> ErrBitGet ()
skip n = readN BLeft n discard
  where
    discard _ = ()

-- | Return the number of bits remaining to be parsed: eight bits per
--   remaining input byte, minus the bits already consumed from the first one.
remaining :: BitGet Int
remaining = gets bitsLeft
  where
    bitsLeft st = B.length (input st) * 8 - fromIntegral (bitOffset st)

-- | Return the number of whole bytes parsed so far
--   (the @totalRead@ field of the state).
bytesParsed :: BitGet Word64
bytesParsed = liftM totalRead get

-- | Return the number of bits parsed so far: eight bits per whole byte read
--   plus the bit offset inside the current byte.
bitsParsed :: BitGet Word64
bitsParsed = gets countBits
  where
    countBits st = totalRead st * 8 + fromIntegral (bitOffset st)

-- | Return True if there are no more bits to parse, i.e. the remaining
--   input ByteString is empty.
isEmpty :: BitGet Bool
isEmpty = gets (B.null . input)

-- | Read @n@ bytes (@n * 8@ bits) through readN and decode a Storable value
--   from the start of the resulting ByteString by peeking directly at its
--   underlying buffer.
--   NOTE(review): the length component of the buffer triple is ignored, so
--   this assumes @n@ is at least the storage size of the result type -
--   TODO confirm at the call sites.
getPtr :: Storable a => Int -> ErrBitGet a
getPtr n = do
    -- toForeignPtr yields (buffer pointer, offset of the data, length)
    (fp, o, _) <- readN BRight (n * 8) BI.toForeignPtr
    -- peek at buffer + offset, reinterpreting the bytes as the result type
    return . BI.inlinePerformIO $ withForeignPtr fp $ \p -> peek (castPtr $ p `plusPtr` o)
{-# INLINE getPtr #-}

-- | Get a single bit from the input: True when the one-bit result read by
--   readN has a non-zero first byte, False otherwise.
getBit :: ErrBitGet Bool
getBit = readN BRight 1 isSet
  where
    isSet bs = B.head bs /= 0

-- | Get a ByteString holding the given number of bits, requested as left
--   aligned.
--   NOTE(review): readN does not currently consult its Direction argument,
--   so the result is the same as for getRightByteString - confirm whether
--   left alignment is still intended here.
getLeftByteString :: Int -> ErrBitGet B.ByteString
getLeftByteString bitCount = readN BLeft bitCount id

-- | Get a ByteString holding the given number of bits, right aligned.
--   Fails if fewer than that many bits remain.
getRightByteString :: Int -> ErrBitGet B.ByteString
getRightByteString bitCount = readN BRight bitCount id

-- | Get a ByteString holding the given number of whole bytes, right
--   aligned. The byte count is converted to bits and handed to
--   getRightByteString.
getRightByteStringBytes :: Int -> ErrBitGet B.ByteString
getRightByteStringBytes byteCount = getRightByteString (8 * byteCount)

-- | Pad a ByteString on the left with zero bytes until it is @len@ bytes
--   long. Inputs that are already @len@ bytes or longer are returned
--   unchanged (they are never truncated).
leftPad :: Int -> B.ByteString -> B.ByteString
leftPad len bs
  | missing > 0 = B.replicate missing 0 `B.append` bs
  | otherwise   = bs
  where
    -- number of zero bytes needed to reach the requested length
    missing = len - B.length bs

-- TODO: consider replacing these preprocessor macros with Template Haskell.
-- The two macro invocations below come from the included Common.h (not
-- visible here); presumably they define the fixed-width word parsers listed
-- in the export list - confirm against that header.
GETWORDS(ErrBitGet, getRightByteStringBytes)
GETHOSTWORDS(ErrBitGet)

-- | Read @n@ bits and return the first byte (index 0) of the right-aligned
--   result as a Word8.
--   NOTE(review): B.index fails on an empty ByteString; presumably @n@ is
--   expected to be between 1 and 8 - confirm with callers.
getAsWord8 :: Int -> ErrBitGet Word8
getAsWord8 n = readN BRight n firstByte
  where
    firstByte bs = B.index bs 0

-- | Read @n@ bits and interpret them as a Word16 in big endian format.
--   The right-aligned result is left-padded with zero bytes to two bytes
--   before decoding; results longer than two bytes are passed on as they
--   are, so @n@ is presumably expected to be at most 16.
getAsWord16 :: Int -> ErrBitGet Word16
getAsWord16 n = do
    padded <- liftM (leftPad 2) (readN BRight n id)
    return $! DECWORD16BE(padded)
-- NOTE(review): the pragma below names getWord16be, not getAsWord16;
-- confirm that this is intended.
{-# INLINE getWord16be #-}

-- | Read @n@ bits and interpret them as a Word32 in big endian format.
--   The right-aligned result is left-padded with zero bytes to four bytes
--   before decoding; results longer than four bytes are passed on as they
--   are, so @n@ is presumably expected to be at most 32.
getAsWord32 :: Int -> ErrBitGet Word32
getAsWord32 n = do
    padded <- liftM (leftPad 4) (readN BRight n id)
    return $! DECWORD32BE(padded)
-- NOTE(review): the pragma below names getWord32be, not getAsWord32;
-- confirm that this is intended.
{-# INLINE getWord32be #-}

-- | Read @n@ bits and interpret them as a Word64 in big endian format.
--   The right-aligned result is left-padded with zero bytes to eight bytes
--   before decoding; results longer than eight bytes are passed on as they
--   are, so @n@ is presumably expected to be at most 64.
getAsWord64 :: Int -> ErrBitGet Word64
getAsWord64 n = do
    padded <- liftM (leftPad 8) (readN BRight n id)
    return $! DECWORD64BE(padded)
-- NOTE(review): the pragma below names getWord64be, not getAsWord64;
-- confirm that this is intended.
{-# INLINE getWord64be #-}

-- Width-specific left shifts. Each one is plain shiftL fixed to a single
-- word type, so the shift amount is in bits and bits shifted past the top
-- of the word are lost.

-- | Left shift on Word16.
shiftl_w16 :: Word16 -> Int -> Word16
shiftl_w16 w amount = w `shiftL` amount

-- | Left shift on Word32.
shiftl_w32 :: Word32 -> Int -> Word32
shiftl_w32 w amount = w `shiftL` amount

-- | Left shift on Word64.
shiftl_w64 :: Word64 -> Int -> Word64
shiftl_w64 w amount = w `shiftL` amount
